{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "slideshow": {
     "slide_type": "-"
    }
   },
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Expose a single GPU to TensorFlow. With PCI_BUS_ID ordering, device indices follow the PCI bus\n",
    "# order (the order nvidia-smi uses) instead of CUDA's default fastest-first order.\n",
    "os.environ.update(CUDA_DEVICE_ORDER='PCI_BUS_ID', CUDA_VISIBLE_DEVICES='1')\n",
    "\n",
    "# Imported only after the CUDA variables are set so that TensorFlow picks them up.\n",
    "import tensorflow as tf\n",
    "\n",
    "# Re-import edited local modules (data generation, loader, model helpers) before every cell execution.\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 121,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Reward Mean: 0.0351\n",
      "Progress: 1000 of 2394\n",
      "Progress: 2000 of 2394\n",
      "Written training data to /home/berscheid/Documents/data/pushing/gen-push/train.csv\n",
      "Written test data to /home/berscheid/Documents/data/pushing/gen-push/test.csv\n"
     ]
    }
   ],
   "source": [
    "from generate_input_planar_pose import GenerateInputPlanarPose\n",
    "\n",
    "directory = os.path.expanduser('~/Documents/data/')\n",
    "\n",
    "# Recorded databases to convert into training data. No separate test databases are given, so the\n",
    "# test split presumably comes from percent_test_set below.\n",
    "train_databases = [directory + 'pushing/cylinder-cube-1.db']\n",
    "test_databases = []  # e.g. directory + 'all-1/all-1.db'\n",
    "\n",
    "generator = GenerateInputPlanarPose(train_databases, test_files=test_databases, output_folder='gen-push/')\n",
    "generator.percent_test_set = 0.2\n",
    "\n",
    "generation_settings = {\n",
    "    'did_grasp_weight': 0.13,\n",
    "    'force_rewrite': True,\n",
    "    'size_cropped': (240, 240),\n",
    "    'size_input': (752, 480),\n",
    "    'size_output': (32, 32),\n",
    "}\n",
    "generator.generateInput(generation_settings)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 166,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Length: 463\n",
      "Reward Mean: 0.0414549364387061\n",
      "Length: 1931\n",
      "Reward Mean: 0.03356556036894687\n",
      "0.04424783875434233\n",
      "0.520727468219353 0.11166881158867335\n"
     ]
    }
   ],
   "source": [
    "from tensorflow_loader import DataLoader\n",
    "\n",
    "test_loader = DataLoader(generator.test_output_filename, label_fields=['reward', 'action_direction'])\n",
    "train_loader = DataLoader(generator.train_output_filename, label_fields=['reward', 'action_direction'], batch_size=128)\n",
    "\n",
    "# Rescale the reward label (column 0) with r -> (r + 1) / 2 so it matches the sigmoid output range;\n",
    "# this assumes raw rewards lie in [-1, 1] -- TODO confirm against the generator.\n",
    "for loader in (test_loader, train_loader):\n",
    "    loader.labels[:, 0] = 0.5 * (loader.labels[:, 0] + 1)\n",
    "\n",
    "# Mean distance of the rescaled reward from 0.5. It normalizes the sample weights of the training loss,\n",
    "# so it must come from the training split only; computing it on test labels leaks them into training.\n",
    "reward_abs_mean = abs(train_loader.labels[:, 0] - 0.5).mean()\n",
    "print('Reward abs mean (train):', reward_abs_mean)\n",
    "\n",
    "print('Reward mean/std (test):', test_loader.labels[:, 0].mean(), test_loader.labels[:, 0].std())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 175,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tensorflow_utils import get_augmentation, single_class_split, tf_accuracy, tf_precision, tf_recall, tf_f1\n",
    "\n",
    "tf.reset_default_graph()\n",
    "tf.set_random_seed(4)\n",
    "\n",
    "# --- Input pipeline ---\n",
    "# One feedable iterator reads from either dataset; the string fed to `handle` selects which one.\n",
    "train_dataset = tf.data.Dataset.from_generator(train_loader.nextBatch, (tf.float32, tf.float32), ((None, None, None, 1), (None, 2)))\n",
    "test_dataset = tf.data.Dataset.from_generator(test_loader.entireBatch, (tf.float32, tf.float32))\n",
    "\n",
    "handle = tf.placeholder(tf.string, shape=(), name='handle')\n",
    "iterator = tf.data.Iterator.from_string_handle(handle, train_dataset.output_types, train_dataset.output_shapes)\n",
    "train_iterator = train_dataset.make_initializable_iterator()\n",
    "test_iterator = test_dataset.make_initializable_iterator()\n",
    "\n",
    "image, label = iterator.get_next()\n",
    "\n",
    "\n",
    "# --- Layer helpers ---\n",
    "def act(x):\n",
    "    \"\"\"Leaky ReLU activation with negative slope 0.2.\"\"\"\n",
    "    return tf.nn.leaky_relu(x, 0.2)\n",
    "\n",
    "\n",
    "def reg(l1=0.0, l2=0.1):\n",
    "    \"\"\"Return an L1/L2 weight regularizer (pure L2 with the defaults).\"\"\"\n",
    "    return tf.contrib.layers.l1_l2_regularizer(l1, l2)\n",
    "\n",
    "\n",
    "# `training` switches batch norm to batch statistics; dropout follows it unless `apply_dropout` is fed.\n",
    "training = tf.placeholder_with_default(False, (), name='training')\n",
    "apply_dropout = tf.placeholder_with_default(training, (), name='apply_dropout')\n",
    "image = tf.identity(image, name='image')\n",
    "label = tf.identity(label, name='label')\n",
    "\n",
    "\n",
    "def conv_block(inputs, filters, kernel_size, dropout_rate, strides=(1, 1), batch_norm=True):\n",
    "    \"\"\"Regularized conv2d with leaky ReLU, optional batch norm, then dropout.\"\"\"\n",
    "    outputs = tf.layers.conv2d(inputs, filters, kernel_size, strides=strides, activation=act, kernel_regularizer=reg(), bias_regularizer=reg())\n",
    "    if batch_norm:\n",
    "        outputs = tf.layers.batch_normalization(outputs, training=training)\n",
    "    return tf.layers.dropout(outputs, dropout_rate, training=apply_dropout)\n",
    "\n",
    "\n",
    "# --- Network: fully convolutional, ending in a 2-channel 1x1 conv ---\n",
    "x = conv_block(image, 32, (5, 5), 0.2, strides=(2, 2))\n",
    "x = conv_block(x, 48, (5, 5), 0.3)\n",
    "x = conv_block(x, 64, (5, 5), 0.4)\n",
    "x = conv_block(x, 142, (6, 6), 0.4)\n",
    "x = conv_block(x, 128, (1, 1), 0.4, batch_norm=False)\n",
    "\n",
    "logits = tf.layers.conv2d(x, 2, (1, 1), bias_regularizer=reg(l2=0.05))\n",
    "prob = tf.nn.sigmoid(logits, name='prob')\n",
    "\n",
    "\n",
    "# --- Loss ---\n",
    "reward, reward_pred = single_class_split(label, prob[:, 0, 0])\n",
    "\n",
    "# Up-weight samples whose reward is far from 0.5; dividing by (0.2 + reward_abs_mean) keeps the\n",
    "# average weight near 1 (assumes `reward` is the rescaled label column 0 -- TODO confirm).\n",
    "weights = (0.2 + tf.abs(reward - 0.5)) / (0.2 + reward_abs_mean)\n",
    "loss = tf.losses.mean_squared_error(labels=reward, predictions=reward_pred, weights=weights)\n",
    "loss_reg = loss + tf.add_n(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))\n",
    "loss2 = tf.losses.mean_squared_error(labels=reward, predictions=reward_pred)  # unweighted MSE\n",
    "\n",
    "# --- Metrics on the reward rounded to a binary label ---\n",
    "accuracy = tf_accuracy(tf.round(reward), reward_pred)\n",
    "precision = tf_precision(tf.round(reward), reward_pred)\n",
    "recall = tf_recall(tf.round(reward), reward_pred)\n",
    "f1 = tf_f1(precision, recall, beta=0.5)  # beta=0.5: presumably an F0.5 score (precision weighted higher)\n",
    "\n",
    "\n",
    "# --- Optimizer ---\n",
    "# Batch norm registers its moving-average updates in UPDATE_OPS; make every train step run them.\n",
    "extra_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)\n",
    "with tf.control_dependencies(extra_update_ops):\n",
    "    train = tf.train.AdamOptimizer(learning_rate=1e-5).minimize(loss_reg, name='train')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 176,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Restoring parameters from /home/berscheid/Documents/data/pushing/models/model-2\n",
      "Epoch\tTime[s]\tTrain Accuracy\tTest Loss\tAccuracy\tPrecision\tRecall\t\tF1\n",
      "   -1\t0.00\t-1.0000\t\t0.0103\t\t0.0103\t\t0.5140\t\t0.7173\t\t0.0067\n",
      "    0\t0.46\t0.0025\t\t0.0103\t\t0.0103\t\t0.5076\t\t0.7147\t\t0.0068\n",
      "    1\t0.57\t0.0014\t\t0.0103\t\t0.0103\t\t0.5119\t\t0.7138\t\t0.0068\n",
      "    2\t0.69\t0.0027\t\t0.0104\t\t0.0104\t\t0.5140\t\t0.7134\t\t0.0068\n",
      "    3\t0.81\t0.0038\t\t0.0104\t\t0.0104\t\t0.5184\t\t0.7131\t\t0.0069\n",
      "    4\t0.93\t0.0024\t\t0.0104\t\t0.0104\t\t0.5162\t\t0.7129\t\t0.0069\n",
      "    5\t1.05\t0.0028\t\t0.0104\t\t0.0104\t\t0.5140\t\t0.7128\t\t0.0069\n",
      "    6\t1.17\t0.0028\t\t0.0104\t\t0.0104\t\t0.5119\t\t0.7127\t\t0.0069\n",
      "    7\t1.29\t0.0028\t\t0.0104\t\t0.0104\t\t0.5119\t\t0.7132\t\t0.0069\n",
      "    8\t1.42\t0.0034\t\t0.0104\t\t0.0104\t\t0.5140\t\t0.7147\t\t0.0069\n",
      "    9\t1.54\t0.0019\t\t0.0104\t\t0.0104\t\t0.5162\t\t0.7158\t\t0.0070\n",
      "   10\t1.66\t0.0025\t\t0.0104\t\t0.0104\t\t0.5140\t\t0.7168\t\t0.0070\n",
      "   11\t1.78\t0.0021\t\t0.0104\t\t0.0104\t\t0.5119\t\t0.7173\t\t0.0070\n",
      "   12\t1.89\t0.0024\t\t0.0104\t\t0.0104\t\t0.5162\t\t0.7180\t\t0.0070\n",
      "   13\t2.01\t0.0027\t\t0.0104\t\t0.0104\t\t0.5184\t\t0.7183\t\t0.0070\n",
      "   14\t2.13\t0.0020\t\t0.0104\t\t0.0104\t\t0.5205\t\t0.7186\t\t0.0070\n",
      "   15\t2.25\t0.0022\t\t0.0104\t\t0.0104\t\t0.5227\t\t0.7191\t\t0.0070\n",
      "   16\t2.37\t0.0020\t\t0.0104\t\t0.0104\t\t0.5205\t\t0.7193\t\t0.0070\n",
      "   17\t2.49\t0.0022\t\t0.0104\t\t0.0104\t\t0.5205\t\t0.7195\t\t0.0070\n",
      "   18\t2.61\t0.0022\t\t0.0104\t\t0.0104\t\t0.5227\t\t0.7196\t\t0.0070\n",
      "   19\t2.73\t0.0023\t\t0.0104\t\t0.0104\t\t0.5227\t\t0.7198\t\t0.0070\n",
      "   20\t2.85\t0.0034\t\t0.0104\t\t0.0104\t\t0.5270\t\t0.7201\t\t0.0070\n",
      "   21\t2.97\t0.0034\t\t0.0104\t\t0.0104\t\t0.5292\t\t0.7205\t\t0.0070\n",
      "   22\t3.08\t0.0016\t\t0.0104\t\t0.0104\t\t0.5313\t\t0.7208\t\t0.0070\n",
      "   23\t3.20\t0.0025\t\t0.0104\t\t0.0104\t\t0.5313\t\t0.7210\t\t0.0070\n",
      "   24\t3.32\t0.0034\t\t0.0104\t\t0.0104\t\t0.5356\t\t0.7213\t\t0.0070\n",
      "   25\t3.44\t0.0038\t\t0.0104\t\t0.0104\t\t0.5335\t\t0.7217\t\t0.0070\n",
      "   26\t3.57\t0.0025\t\t0.0104\t\t0.0104\t\t0.5313\t\t0.7221\t\t0.0070\n",
      "   27\t3.69\t0.0021\t\t0.0104\t\t0.0104\t\t0.5356\t\t0.7225\t\t0.0070\n",
      "   28\t3.82\t0.0033\t\t0.0104\t\t0.0104\t\t0.5378\t\t0.7229\t\t0.0070\n",
      "   29\t3.93\t0.0038\t\t0.0104\t\t0.0104\t\t0.5356\t\t0.7232\t\t0.0070\n",
      "   30\t4.05\t0.0043\t\t0.0104\t\t0.0104\t\t0.5356\t\t0.7235\t\t0.0070\n",
      "   31\t4.16\t0.0021\t\t0.0104\t\t0.0104\t\t0.5378\t\t0.7238\t\t0.0070\n",
      "   32\t4.28\t0.0030\t\t0.0104\t\t0.0104\t\t0.5356\t\t0.7241\t\t0.0069\n",
      "   33\t4.40\t0.0028\t\t0.0104\t\t0.0104\t\t0.5378\t\t0.7243\t\t0.0069\n",
      "   34\t4.52\t0.0028\t\t0.0104\t\t0.0104\t\t0.5356\t\t0.7245\t\t0.0070\n",
      "   35\t4.63\t0.0021\t\t0.0104\t\t0.0104\t\t0.5356\t\t0.7248\t\t0.0069\n",
      "   36\t4.75\t0.0029\t\t0.0104\t\t0.0104\t\t0.5378\t\t0.7250\t\t0.0069\n",
      "   37\t4.87\t0.0030\t\t0.0104\t\t0.0104\t\t0.5378\t\t0.7252\t\t0.0069\n",
      "   38\t4.99\t0.0035\t\t0.0104\t\t0.0104\t\t0.5313\t\t0.7253\t\t0.0069\n",
      "   39\t5.11\t0.0031\t\t0.0104\t\t0.0104\t\t0.5356\t\t0.7253\t\t0.0069\n",
      "   40\t5.23\t0.0021\t\t0.0104\t\t0.0104\t\t0.5356\t\t0.7254\t\t0.0069\n",
      "   41\t5.35\t0.0033\t\t0.0104\t\t0.0104\t\t0.5400\t\t0.7254\t\t0.0069\n",
      "   42\t5.47\t0.0026\t\t0.0104\t\t0.0104\t\t0.5378\t\t0.7254\t\t0.0069\n",
      "   43\t5.59\t0.0025\t\t0.0104\t\t0.0104\t\t0.5356\t\t0.7252\t\t0.0069\n",
      "   44\t5.71\t0.0036\t\t0.0104\t\t0.0104\t\t0.5400\t\t0.7250\t\t0.0069\n",
      "   45\t5.83\t0.0026\t\t0.0104\t\t0.0104\t\t0.5464\t\t0.7250\t\t0.0069\n",
      "   46\t5.95\t0.0032\t\t0.0104\t\t0.0104\t\t0.5464\t\t0.7249\t\t0.0069\n",
      "   47\t6.07\t0.0029\t\t0.0104\t\t0.0104\t\t0.5486\t\t0.7249\t\t0.0069\n",
      "   48\t6.19\t0.0032\t\t0.0104\t\t0.0104\t\t0.5486\t\t0.7248\t\t0.0069\n",
      "   49\t6.31\t0.0021\t\t0.0104\t\t0.0104\t\t0.5400\t\t0.7245\t\t0.0069\n",
      "   50\t6.43\t0.0026\t\t0.0104\t\t0.0104\t\t0.5400\t\t0.7241\t\t0.0069\n",
      "   51\t6.55\t0.0048\t\t0.0104\t\t0.0104\t\t0.5400\t\t0.7238\t\t0.0069\n",
      "   52\t6.66\t0.0025\t\t0.0104\t\t0.0104\t\t0.5464\t\t0.7237\t\t0.0069\n",
      "   53\t6.79\t0.0031\t\t0.0104\t\t0.0104\t\t0.5421\t\t0.7235\t\t0.0069\n",
      "   54\t6.90\t0.0027\t\t0.0104\t\t0.0104\t\t0.5443\t\t0.7233\t\t0.0069\n",
      "   55\t7.02\t0.0024\t\t0.0104\t\t0.0104\t\t0.5443\t\t0.7230\t\t0.0069\n",
      "   56\t7.15\t0.0022\t\t0.0104\t\t0.0104\t\t0.5443\t\t0.7227\t\t0.0069\n",
      "   57\t7.27\t0.0030\t\t0.0104\t\t0.0104\t\t0.5486\t\t0.7226\t\t0.0069\n",
      "   58\t7.39\t0.0028\t\t0.0104\t\t0.0104\t\t0.5464\t\t0.7223\t\t0.0069\n",
      "   59\t7.51\t0.0019\t\t0.0104\t\t0.0104\t\t0.5464\t\t0.7221\t\t0.0069\n",
      "   60\t7.63\t0.0040\t\t0.0104\t\t0.0104\t\t0.5464\t\t0.7218\t\t0.0069\n",
      "   61\t7.75\t0.0032\t\t0.0104\t\t0.0104\t\t0.5464\t\t0.7216\t\t0.0068\n",
      "   62\t7.87\t0.0045\t\t0.0104\t\t0.0104\t\t0.5486\t\t0.7214\t\t0.0069\n",
      "   63\t7.98\t0.0015\t\t0.0104\t\t0.0104\t\t0.5508\t\t0.7212\t\t0.0068\n",
      "   64\t8.10\t0.0024\t\t0.0104\t\t0.0104\t\t0.5508\t\t0.7211\t\t0.0068\n",
      "   65\t8.22\t0.0018\t\t0.0104\t\t0.0104\t\t0.5508\t\t0.7210\t\t0.0068\n",
      "   66\t8.34\t0.0020\t\t0.0104\t\t0.0104\t\t0.5508\t\t0.7208\t\t0.0068\n",
      "   67\t8.46\t0.0025\t\t0.0104\t\t0.0104\t\t0.5508\t\t0.7207\t\t0.0068\n",
      "   68\t8.57\t0.0019\t\t0.0104\t\t0.0104\t\t0.5508\t\t0.7206\t\t0.0068\n",
      "   69\t8.69\t0.0034\t\t0.0104\t\t0.0104\t\t0.5529\t\t0.7204\t\t0.0068\n",
      "   70\t8.81\t0.0020\t\t0.0104\t\t0.0104\t\t0.5529\t\t0.7203\t\t0.0068\n",
      "   71\t8.93\t0.0039\t\t0.0104\t\t0.0104\t\t0.5508\t\t0.7201\t\t0.0068\n",
      "   72\t9.04\t0.0024\t\t0.0104\t\t0.0104\t\t0.5529\t\t0.7200\t\t0.0068\n",
      "   73\t9.16\t0.0023\t\t0.0104\t\t0.0104\t\t0.5529\t\t0.7198\t\t0.0068\n",
      "   74\t9.27\t0.0043\t\t0.0104\t\t0.0104\t\t0.5572\t\t0.7196\t\t0.0068\n",
      "   75\t9.38\t0.0031\t\t0.0104\t\t0.0104\t\t0.5551\t\t0.7195\t\t0.0068\n",
      "   76\t9.50\t0.0018\t\t0.0104\t\t0.0104\t\t0.5551\t\t0.7193\t\t0.0068\n",
      "   77\t9.61\t0.0026\t\t0.0104\t\t0.0104\t\t0.5572\t\t0.7191\t\t0.0068\n",
      "   78\t9.73\t0.0019\t\t0.0104\t\t0.0104\t\t0.5594\t\t0.7190\t\t0.0068\n",
      "   79\t9.84\t0.0024\t\t0.0104\t\t0.0104\t\t0.5594\t\t0.7188\t\t0.0068\n",
      "   80\t9.96\t0.0028\t\t0.0104\t\t0.0104\t\t0.5616\t\t0.7186\t\t0.0068\n",
      "   81\t10.08\t0.0022\t\t0.0104\t\t0.0104\t\t0.5616\t\t0.7183\t\t0.0068\n",
      "   82\t10.19\t0.0028\t\t0.0104\t\t0.0104\t\t0.5616\t\t0.7182\t\t0.0068\n",
      "   83\t10.31\t0.0025\t\t0.0104\t\t0.0104\t\t0.5594\t\t0.7179\t\t0.0068\n",
      "   84\t10.43\t0.0024\t\t0.0104\t\t0.0104\t\t0.5594\t\t0.7177\t\t0.0068\n",
      "   85\t10.54\t0.0025\t\t0.0104\t\t0.0104\t\t0.5594\t\t0.7175\t\t0.0068\n",
      "   86\t10.66\t0.0037\t\t0.0104\t\t0.0104\t\t0.5572\t\t0.7173\t\t0.0068\n",
      "   87\t10.77\t0.0019\t\t0.0104\t\t0.0104\t\t0.5594\t\t0.7170\t\t0.0068\n",
      "   88\t10.89\t0.0025\t\t0.0104\t\t0.0104\t\t0.5594\t\t0.7168\t\t0.0068\n",
      "   89\t11.01\t0.0048\t\t0.0104\t\t0.0104\t\t0.5594\t\t0.7166\t\t0.0068\n",
      "   90\t11.13\t0.0026\t\t0.0104\t\t0.0104\t\t0.5594\t\t0.7164\t\t0.0068\n",
      "   91\t11.25\t0.0020\t\t0.0104\t\t0.0104\t\t0.5594\t\t0.7162\t\t0.0068\n",
      "   92\t11.37\t0.0033\t\t0.0104\t\t0.0104\t\t0.5572\t\t0.7159\t\t0.0068\n",
      "   93\t11.49\t0.0027\t\t0.0104\t\t0.0104\t\t0.5572\t\t0.7157\t\t0.0068\n",
      "   94\t11.61\t0.0035\t\t0.0104\t\t0.0104\t\t0.5572\t\t0.7155\t\t0.0068\n",
      "   95\t11.73\t0.0021\t\t0.0104\t\t0.0104\t\t0.5594\t\t0.7153\t\t0.0068\n",
      "   96\t11.86\t0.0026\t\t0.0104\t\t0.0104\t\t0.5616\t\t0.7151\t\t0.0068\n",
      "   97\t11.98\t0.0025\t\t0.0104\t\t0.0104\t\t0.5594\t\t0.7150\t\t0.0068\n",
      "   98\t12.10\t0.0019\t\t0.0104\t\t0.0104\t\t0.5594\t\t0.7148\t\t0.0068\n",
      "   99\t12.22\t0.0040\t\t0.0104\t\t0.0104\t\t0.5572\t\t0.7146\t\t0.0067\n",
      "  100\t12.34\t0.0029\t\t0.0104\t\t0.0104\t\t0.5572\t\t0.7144\t\t0.0067\n",
      "  101\t12.46\t0.0028\t\t0.0104\t\t0.0104\t\t0.5572\t\t0.7142\t\t0.0067\n",
      "  102\t12.58\t0.0024\t\t0.0104\t\t0.0104\t\t0.5594\t\t0.7141\t\t0.0067\n",
      "  103\t12.70\t0.0032\t\t0.0104\t\t0.0104\t\t0.5572\t\t0.7139\t\t0.0067\n",
      "  104\t12.83\t0.0020\t\t0.0104\t\t0.0104\t\t0.5594\t\t0.7137\t\t0.0067\n",
      "  105\t12.94\t0.0027\t\t0.0104\t\t0.0104\t\t0.5594\t\t0.7136\t\t0.0067\n",
      "  106\t13.06\t0.0036\t\t0.0104\t\t0.0104\t\t0.5594\t\t0.7134\t\t0.0067\n",
      "  107\t13.18\t0.0036\t\t0.0104\t\t0.0104\t\t0.5594\t\t0.7132\t\t0.0067\n",
      "  108\t13.30\t0.0032\t\t0.0104\t\t0.0104\t\t0.5616\t\t0.7131\t\t0.0067\n",
      "  109\t13.42\t0.0020\t\t0.0104\t\t0.0104\t\t0.5616\t\t0.7129\t\t0.0067\n",
      "  110\t13.54\t0.0037\t\t0.0104\t\t0.0104\t\t0.5594\t\t0.7126\t\t0.0067\n",
      "  111\t13.66\t0.0026\t\t0.0104\t\t0.0104\t\t0.5594\t\t0.7124\t\t0.0067\n",
      "  112\t13.79\t0.0027\t\t0.0104\t\t0.0104\t\t0.5551\t\t0.7121\t\t0.0067\n",
      "  113\t13.90\t0.0047\t\t0.0104\t\t0.0104\t\t0.5572\t\t0.7119\t\t0.0067\n",
      "  114\t14.02\t0.0033\t\t0.0104\t\t0.0104\t\t0.5594\t\t0.7117\t\t0.0067\n",
      "  115\t14.13\t0.0043\t\t0.0104\t\t0.0104\t\t0.5594\t\t0.7115\t\t0.0067\n",
      "  116\t14.25\t0.0027\t\t0.0104\t\t0.0104\t\t0.5637\t\t0.7112\t\t0.0067\n",
      "  117\t14.37\t0.0039\t\t0.0104\t\t0.0104\t\t0.5637\t\t0.7110\t\t0.0067\n",
      "  118\t14.50\t0.0029\t\t0.0104\t\t0.0104\t\t0.5659\t\t0.7108\t\t0.0067\n",
      "  119\t14.62\t0.0025\t\t0.0104\t\t0.0104\t\t0.5637\t\t0.7106\t\t0.0067\n",
      "  120\t14.74\t0.0041\t\t0.0104\t\t0.0104\t\t0.5616\t\t0.7104\t\t0.0067\n",
      "  121\t14.85\t0.0030\t\t0.0104\t\t0.0104\t\t0.5637\t\t0.7102\t\t0.0067\n",
      "  122\t14.97\t0.0015\t\t0.0104\t\t0.0104\t\t0.5680\t\t0.7100\t\t0.0067\n",
      "  123\t15.08\t0.0025\t\t0.0104\t\t0.0104\t\t0.5637\t\t0.7098\t\t0.0067\n",
      "  124\t15.20\t0.0027\t\t0.0104\t\t0.0104\t\t0.5594\t\t0.7095\t\t0.0067\n",
      "  125\t15.32\t0.0037\t\t0.0104\t\t0.0104\t\t0.5637\t\t0.7093\t\t0.0067\n",
      "  126\t15.43\t0.0027\t\t0.0104\t\t0.0104\t\t0.5616\t\t0.7091\t\t0.0067\n",
      "  127\t15.54\t0.0018\t\t0.0104\t\t0.0104\t\t0.5637\t\t0.7088\t\t0.0067\n",
      "  128\t15.66\t0.0044\t\t0.0104\t\t0.0104\t\t0.5680\t\t0.7087\t\t0.0067\n",
      "  129\t15.78\t0.0029\t\t0.0104\t\t0.0104\t\t0.5680\t\t0.7085\t\t0.0067\n",
      "  130\t15.90\t0.0037\t\t0.0104\t\t0.0104\t\t0.5680\t\t0.7083\t\t0.0067\n",
      "  131\t16.02\t0.0033\t\t0.0104\t\t0.0104\t\t0.5680\t\t0.7081\t\t0.0067\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-176-d7885c475050>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m     15\u001b[0m     \u001b[0mload\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     16\u001b[0m     \u001b[0msave\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 17\u001b[0;31m     \u001b[0mexport\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     18\u001b[0m )\n",
      "\u001b[0;32m~/Documents/bin_picking/jupyter/tensorflow_model.py\u001b[0m in \u001b[0;36mfit\u001b[0;34m(self, train_iterator, train_batches_per_epoch, test_iterator, epochs, early_stopping_patience, load, save, export)\u001b[0m\n\u001b[1;32m    108\u001b[0m             \u001b[0;32mfor\u001b[0m \u001b[0mbatch\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrain_batches_per_epoch\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    109\u001b[0m \u001b[0;31m#            for batch in tqdm(range(train_batches_per_epoch), unit='batches', leave=False, dynamic_ncols=True):\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 110\u001b[0;31m                 \u001b[0mtrain_data\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msess\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'train'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtest_metrices\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfeed_dict\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m{\u001b[0m\u001b[0;34m'handle:0'\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrain_handle\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'training:0'\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    111\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    112\u001b[0m             \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0monEpochEnd\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mepoch\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtrain_data\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mearly_stopping_patience\u001b[0m\u001b[0;34m,\u001b[0m 
\u001b[0msave\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/.local/lib/python3.5/site-packages/tensorflow/python/client/session.py\u001b[0m in \u001b[0;36mrun\u001b[0;34m(self, fetches, feed_dict, options, run_metadata)\u001b[0m\n\u001b[1;32m    875\u001b[0m     \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    876\u001b[0m       result = self._run(None, fetches, feed_dict, options_ptr,\n\u001b[0;32m--> 877\u001b[0;31m                          run_metadata_ptr)\n\u001b[0m\u001b[1;32m    878\u001b[0m       \u001b[0;32mif\u001b[0m \u001b[0mrun_metadata\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    879\u001b[0m         \u001b[0mproto_data\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtf_session\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mTF_GetBuffer\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrun_metadata_ptr\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/.local/lib/python3.5/site-packages/tensorflow/python/client/session.py\u001b[0m in \u001b[0;36m_run\u001b[0;34m(self, handle, fetches, feed_dict, options, run_metadata)\u001b[0m\n\u001b[1;32m   1098\u001b[0m     \u001b[0;32mif\u001b[0m \u001b[0mfinal_fetches\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0mfinal_targets\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mhandle\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mfeed_dict_tensor\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1099\u001b[0m       results = self._do_run(handle, final_targets, final_fetches,\n\u001b[0;32m-> 1100\u001b[0;31m                              feed_dict_tensor, options, run_metadata)\n\u001b[0m\u001b[1;32m   1101\u001b[0m     \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1102\u001b[0m       \u001b[0mresults\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/.local/lib/python3.5/site-packages/tensorflow/python/client/session.py\u001b[0m in \u001b[0;36m_do_run\u001b[0;34m(self, handle, target_list, fetch_list, feed_dict, options, run_metadata)\u001b[0m\n\u001b[1;32m   1270\u001b[0m     \u001b[0;32mif\u001b[0m \u001b[0mhandle\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1271\u001b[0m       return self._do_call(_run_fn, feeds, fetches, targets, options,\n\u001b[0;32m-> 1272\u001b[0;31m                            run_metadata)\n\u001b[0m\u001b[1;32m   1273\u001b[0m     \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1274\u001b[0m       \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_do_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0m_prun_fn\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhandle\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfeeds\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfetches\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/.local/lib/python3.5/site-packages/tensorflow/python/client/session.py\u001b[0m in \u001b[0;36m_do_call\u001b[0;34m(self, fn, *args)\u001b[0m\n\u001b[1;32m   1276\u001b[0m   \u001b[0;32mdef\u001b[0m \u001b[0m_do_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1277\u001b[0m     \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1278\u001b[0;31m       \u001b[0;32mreturn\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1279\u001b[0m     \u001b[0;32mexcept\u001b[0m \u001b[0merrors\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mOpError\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1280\u001b[0m       \u001b[0mmessage\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcompat\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mas_text\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0me\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmessage\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/.local/lib/python3.5/site-packages/tensorflow/python/client/session.py\u001b[0m in \u001b[0;36m_run_fn\u001b[0;34m(feed_dict, fetch_list, target_list, options, run_metadata)\u001b[0m\n\u001b[1;32m   1261\u001b[0m       \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_extend_graph\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1262\u001b[0m       return self._call_tf_sessionrun(\n\u001b[0;32m-> 1263\u001b[0;31m           options, feed_dict, fetch_list, target_list, run_metadata)\n\u001b[0m\u001b[1;32m   1264\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1265\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0m_prun_fn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mhandle\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfeed_dict\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfetch_list\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/.local/lib/python3.5/site-packages/tensorflow/python/client/session.py\u001b[0m in \u001b[0;36m_call_tf_sessionrun\u001b[0;34m(self, options, feed_dict, fetch_list, target_list, run_metadata)\u001b[0m\n\u001b[1;32m   1348\u001b[0m     return tf_session.TF_SessionRun_wrapper(\n\u001b[1;32m   1349\u001b[0m         \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_session\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moptions\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfeed_dict\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfetch_list\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtarget_list\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1350\u001b[0;31m         run_metadata)\n\u001b[0m\u001b[1;32m   1351\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1352\u001b[0m   \u001b[0;32mdef\u001b[0m \u001b[0m_call_tf_sessionprun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhandle\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfeed_dict\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfetch_list\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "from tensorflow_model import Model\n",
    "\n",
    "# Checkpoint to resume from; with save=True training presumably writes back to the same path.\n",
    "checkpoint_path = generator.model_directory + 'model-2'\n",
    "\n",
    "# NOTE(review): Model prints a fixed header (Train Accuracy, Test Loss, Accuracy, Precision, Recall, F1)\n",
    "# and uses element [1] as the train metric, but this list is loss, loss, accuracy, recall, loss2. So\n",
    "# 'Train Accuracy' shows the train loss, 'Precision' shows accuracy and 'F1' shows the unweighted MSE;\n",
    "# precision and f1 are never reported.\n",
    "evaluation_metrics = [loss, loss, accuracy, recall, loss2]\n",
    "\n",
    "model = Model(\n",
    "    test_metrices=evaluation_metrics,\n",
    "    inputs={'image': image, 'apply_dropout': apply_dropout},\n",
    "    outputs={'prob': prob},\n",
    "    model_input_path=checkpoint_path,\n",
    "    model_output_path=checkpoint_path,\n",
    ")\n",
    "\n",
    "model.fit(\n",
    "    train_iterator,\n",
    "    train_loader.batches_per_epoch,\n",
    "    test_iterator,\n",
    "    epochs=5000,\n",
    "    early_stopping_patience=3000,\n",
    "    load=True,\n",
    "    save=True,\n",
    "    export=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
